<?php

namespace DataCube\DataCubeAggregation\AI_Toolkit\Classification;

use DataCube\DataCubeAggregation\AI_Toolkit\Interfaces\TrainerInterface;
use DataCube\DataCubeAggregation\Exception\CustomException;
use Rubix\ML\Classifiers\LogisticRegression as RubixLogisticRegression;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Kernels\Distance\Manhattan;
use Symfony\Component\OptionsResolver\OptionsResolver;

class LogisticRegression extends BaseClassifier implements TrainerInterface
{
    /**
     * Resolve the options and build the underlying Rubix ML logistic regression estimator.
     *
     * @param array $options Overrides for the keys declared in configureOptions(); any key
     *                       not declared there is rejected by Symfony's OptionsResolver.
     */
    public function __construct(array $options = [])
    {
        $resolver = new OptionsResolver();
        $this->configureOptions($resolver);

        $this->options = $resolver->resolve($options);

        // The Rubix constructor takes these positionally; keep this exact order.
        $this->classifier = new RubixLogisticRegression(
            $this->options['batchSize'],
            $this->options['optimizer'],
            $this->options['l2Penalty'],
            $this->options['epochs'],
            $this->options['minChange'],
            $this->options['window'],
            $this->options['costFn'],
        );
    }

    /**
     * Declare the accepted options and their defaults.
     *
     * Each option is forwarded positionally to the Rubix ML LogisticRegression constructor.
     * The defaults for batchSize (64) and epochs (100) differ from the Rubix constructor's
     * own defaults (128 and 1000).
     *
     * Options:
     *  - batchSize (int)   Number of samples per mini-batch. Default 64.
     *  - optimizer         Gradient descent optimizer. Default: Adam with rate 0.001.
     *  - l2Penalty (float) L2 regularization amount. Default 1e-4.
     *  - epochs (int)      Maximum number of training epochs. Default 100.
     *  - minChange (float) Minimum change in loss required to keep training. Default 1e-4.
     *  - window (int)      Epochs to consider when checking for early stopping. Default 5.
     *  - costFn            Classification loss. Default: CrossEntropy.
     *
     * A fresh resolver is configured in every constructor call, so each instance gets its
     * own optimizer and cost-function objects.
     *
     * @param OptionsResolver $resolver
     * @return void
     */
    public function configureOptions(OptionsResolver $resolver): void
    {
        $resolver->setDefaults([
            'batchSize' => 64,
            'optimizer' => new \Rubix\ML\NeuralNet\Optimizers\Adam(0.001),
            'l2Penalty' => 1e-4,
            'epochs' => 100,
            'minChange' => 1e-4,
            'window' => 5,
            'costFn' => new \Rubix\ML\NeuralNet\CostFunctions\CrossEntropy(),
        ]);
    }

    /**
     * Fit the estimator on a labeled dataset.
     *
     * @param array $samples Sample feature vectors, one per entry in $targets.
     * @param array $targets Class labels aligned index-by-index with $samples.
     * @return void
     * @throws CustomException If the dataset cannot be built or training fails.
     */
    public function train(array $samples, array $targets): void
    {
        try {
            $this->classifier->train(new Labeled($samples, $targets));
        } catch (\Exception $e) {
            // Do not swallow: a silent failure would leave the model untrained with no
            // signal to the caller. Report it the same way predict() does.
            throw new CustomException('Could not train: ' . $e->getMessage());
        }
    }

    /**
     * Predict the class label of a single sample.
     *
     * The argument is wrapped in an outer array, so it is treated as the feature vector of
     * ONE sample (a one-row dataset), not as a list of samples.
     *
     * @param array $samples Feature vector of a single sample.
     * @return array Predictions returned by the Rubix estimator; since exactly one sample is
     *               passed, presumably a single-element list.
     * @throws CustomException If the underlying estimator fails (e.g. it has not been trained
     *                         or the sample is rejected).
     */
    public function predict(array $samples)
    {
        try {
            return $this->classifier->predict(new Unlabeled([$samples]));
        } catch (\Exception $e) {
            // Keep the cause's message so failures remain diagnosable.
            throw new CustomException('Could not predict: ' . $e->getMessage());
        }
    }

    /**
     * Return the progress table from the last training session.
     *
     * The wrapper keeps its public name step() for existing callers. The Rubix estimator
     * defines no step() method, so this delegates to its steps() method.
     *
     * @return iterable
     */
    public function step()
    {
        return $this->classifier->steps();
    }

    /**
     * Return the per-epoch training losses recorded by the underlying estimator.
     *
     * @return array|null Presumably null when the model has not been trained yet.
     */
    public function losses()
    {
        return $this->classifier->losses();
    }

    /**
     * Return the underlying neural network instance of the Rubix estimator.
     *
     * @return mixed Presumably null when the model has not been trained yet.
     */
    public function network()
    {
        return $this->classifier->network();
    }
}